/*
 * This file is part of MMM.
 *
 * MMM is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * MMM is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with MMM.  If not, see <http://www.gnu.org/licenses/>.
 *
 * @package    MMM
 * @author     Christian Mandery <mandery@kit.edu>
 * @copyright  2015 High Performance Humanoid Technologies (H2T), Karlsruhe, Germany
 *
 */


#include <VirtualRobot/RuntimeEnvironment.h>
#include <VirtualRobot/RobotNodeSet.h>
#include <VirtualRobot/IK/DifferentialIK.h>
#include <MMM/Motion/Legacy/LegacyMotionReaderXML.h>
#include <MMMSimoxTools/MMMSimoxTools.h>
#include "MMMDynamicsCalculatorConfiguration.h"

using namespace std;

// The following two classes derived from MotionFrameEntry are currently only defined in this file. To let the existing MMM tools
// make use of the new information contained in the generated XML files, these classes should later be moved to a shared location.
/**
 * Motion frame entry holding the whole-body center of mass (CoM) of a single frame.
 * Serialized as one XML element "<CoM>x y z</CoM>", indented by four tabs.
 */
class MotionFrameCoM : public MMM::MotionFrameEntry
{
public:
    /// Creates an entry with tag name "CoM" whose CoM is initialized to zero.
    MotionFrameCoM() : MotionFrameEntry("CoM"), CoM(Eigen::Vector3f::Zero()) {}

    /// Returns the XML representation: the three CoM components separated by spaces, followed by a newline.
    virtual string toXML()
    {
        const string indent(4, '\t');
        stringstream xml;
        xml << indent << '<' << tagName << '>'
            << CoM(0) << ' ' << CoM(1) << ' ' << CoM(2)
            << "</" << tagName << '>' << '\n';
        return xml.str();
    }

    /// Whole-body center of mass of the frame (as computed by the caller).
    Eigen::Vector3f CoM;
};

/**
 * Motion frame entry holding the angular momentum of a single frame.
 * Serialized as one XML element "<AngularMomentum>x y z</AngularMomentum>", indented by four tabs.
 */
class MotionFrameAngularMomentum : public MMM::MotionFrameEntry
{
public:
    /// Creates an entry with tag name "AngularMomentum" whose value is initialized to zero.
    MotionFrameAngularMomentum() : MotionFrameEntry("AngularMomentum"), AM(Eigen::Vector3f::Zero()) {}

    /// Returns the XML representation: the three components separated by spaces, followed by a newline.
    virtual string toXML()
    {
        const string indent(4, '\t');
        stringstream xml;
        xml << indent << '<' << tagName << '>'
            << AM(0) << ' ' << AM(1) << ' ' << AM(2)
            << "</" << tagName << '>' << '\n';
        return xml.str();
    }

    /// Angular momentum of the frame (accumulated by the caller).
    Eigen::Vector3f AM;
};

typedef boost::shared_ptr<MotionFrameCoM> MotionFrameCoMPtr;
typedef boost::shared_ptr<MotionFrameAngularMomentum> MotionFrameAngularMomentumPtr;

/**
 * Kinematic data of one robot segment in one motion frame. It is pre-computed for all frames so that
 * the data of neighbouring frames is available when velocities are approximated by finite differences.
 */
struct FrameSegmentData
{
    Eigen::Vector3f position;   // global position of the segment's CoM (getCoMGlobal(); presumably mm - confirm)
    Eigen::Matrix3f rotation;   // rotational 3x3 part of the segment's global pose
    Eigen::MatrixXf jacobian;   // Jacobian of the segment w.r.t. the joints of the robot node set
};

int main(int argc, char *argv[])
{
    MMM_INFO << " --- MMMDynamicsCalculator --- " << endl;
    MMMDynamicsCalculatorConfiguration c;
    if (!c.processCommandLine(argc, argv))
    {
        MMM_ERROR << "Error while processing command line, aborting..." << endl;
        return -1;
    }

    MMM_INFO << "Reading motion file..." << endl;
    MMM::LegacyMotionReaderXMLPtr motionReader(new MMM::LegacyMotionReaderXML());

    vector<string> motionNames = motionReader->getMotionNames(c.inputMotionPath);
    if (motionNames.size() != 1)
    {
        MMM_ERROR << "Input XML file must contain exactly one motion!" << endl;
        return -1;
    }

    MMM::LegacyMotionPtr motion = motionReader->loadMotion(c.inputMotionPath, motionNames[0]);
    if (!motion)
    {
        MMM_ERROR << "Could not load motion!" << endl;
        return -1;
    }

    size_t frameCount = motion->getNumFrames();
    float timeStep = motion->getMotionFrame(1)->timestep - motion->getMotionFrame(0)->timestep;
    MMM_INFO << "Processing motion with " << frameCount << " frames (time step " << timeStep << ")..." << endl;

    // Create Simox objects
    VirtualRobot::RobotPtr robot = MMM::SimoxTools::buildModel(motion->getModel(), false);
    VirtualRobot::RobotNodeSetPtr robotNodeSet = VirtualRobot::RobotNodeSet::createRobotNodeSet(robot, "MMMDynamicsCalculatorRobotNodeSet",
                                                                                                motion->getJointNames(), "", "", true);

    // Determine segments to be used for Angular Momentum calculation
    vector<string> segmentsWithMass;
    vector<VirtualRobot::RobotNodePtr> robotNodes = robot->getRobotNodes();
    for (vector<VirtualRobot::RobotNodePtr>::const_iterator i = robotNodes.begin(); i != robotNodes.end(); ++i)
    {
        string segmentName = (*i)->getName();
        float mass = (*i)->getMass();

        MMM_INFO << "Segment " << segmentName << ": Mass = " << mass << "kg" << endl;

        if (mass > 0)
        {
            segmentsWithMass.push_back(segmentName);
        }
    }

    // Pre-compute segment position, rotation and jacobian for all frames
    MMM_INFO << "Pre-computing data for all segments in all frames..." << endl;
    VirtualRobot::DifferentialIKPtr diffIK(new VirtualRobot::DifferentialIK(robotNodeSet));
    vector<map<string, FrameSegmentData> > frameSegmentData;

    for (size_t frameNumber = 0; frameNumber < frameCount; ++frameNumber)
    {
        MMM::MotionFramePtr motionFrame = motion->getMotionFrame(frameNumber);

        // Set pose and joint values
        robot->setGlobalPose(motionFrame->getRootPose());
        robotNodeSet->setJointValues(motionFrame->joint);

        map<string, FrameSegmentData> fsdMap;

        for (vector<string>::const_iterator curSegment = segmentsWithMass.begin(); curSegment != segmentsWithMass.end(); ++curSegment)
        {
            VirtualRobot::RobotNodePtr segment = robot->getRobotNode(*curSegment);

            FrameSegmentData fsd;
            fsd.position = segment->getCoMGlobal();
            fsd.rotation = segment->getGlobalPose().block(0, 0, 3, 3);
            fsd.jacobian = diffIK->getJacobianMatrix(segment);

            /* MMM_INFO << "Frame " << frameNumber << ", segment " << *curSegment << ": pos = " << fsd.position.transpose() << " rot = "
                     << fsd.rotation.transpose() << " jac = " << fsd.jacobian << endl; */
            fsdMap[*curSegment] = fsd;
        }

        frameSegmentData.push_back(fsdMap);
    }

    MMM_INFO << "Computing CoM..." << endl;
    vector<Eigen::Vector3f> frameCoMs;
    for (size_t frameNumber = 0; frameNumber < frameCount; ++frameNumber)
    {
        MMM::MotionFramePtr motionFrame = motion->getMotionFrame(frameNumber);

        // Set pose and joint values
        robot->setGlobalPose(motionFrame->getRootPose());
        robotNodeSet->setJointValues(motionFrame->joint);

        // CoM
        MotionFrameCoMPtr entryCoM(new MotionFrameCoM());
        entryCoM->CoM = robot->getCoMGlobal();
        MMM_INFO << "Frame " << frameNumber << ": CoM = " << entryCoM->CoM.transpose() << endl;
        motionFrame->addEntry("CoM", entryCoM);
        frameCoMs.push_back(entryCoM->CoM);
    }

    MMM_INFO << "Computing Angular Momentum..." << endl;
    for (size_t frameNumber = 0; frameNumber < frameCount; ++frameNumber)
    {
        // MMM_INFO << "Starting frame " << frameNumber << "..." << endl;

        MMM::MotionFramePtr motionFrame = motion->getMotionFrame(frameNumber);

        // Set pose and joint values
        robot->setGlobalPose(motionFrame->getRootPose());
        robotNodeSet->setJointValues(motionFrame->joint);

        // Angular Momentum
        MotionFrameAngularMomentumPtr entryAngularMomentum(new MotionFrameAngularMomentum());

        for (vector<string>::const_iterator curSegment = segmentsWithMass.begin(); curSegment != segmentsWithMass.end(); ++curSegment)
        {
            VirtualRobot::RobotNodePtr segment = robot->getRobotNode(*curSegment);

            Eigen::Matrix3f robotOrientation = robot->getGlobalPose().block(0, 0, 3, 3);

            // Calculate linear velocity
            Eigen::Vector3f segmentVelocity, CoMVelocity;
            if (frameNumber == 0)
            {
                segmentVelocity = frameSegmentData[1][*curSegment].position - frameSegmentData[0][*curSegment].position;
                CoMVelocity = frameCoMs[1] - frameCoMs[0];
            }
            else if (frameNumber == frameCount - 1)
            {
                segmentVelocity = frameSegmentData[frameCount - 1][*curSegment].position -
                        frameSegmentData[frameCount - 2][*curSegment].position;
                CoMVelocity = frameCoMs[frameCount - 1] - frameCoMs[frameCount - 2];
            }
            else
            {
                segmentVelocity = (frameSegmentData[frameNumber + 1][*curSegment].position -
                        frameSegmentData[frameNumber - 1][*curSegment].position) / 2;
                CoMVelocity = (frameCoMs[frameNumber + 1] - frameCoMs[frameNumber - 1]) / 2;
            }

            Eigen::Vector3f linearVelocity = (segmentVelocity - CoMVelocity) / timeStep;

            // Calculate angular velocity
            Eigen::VectorXf jointVelocity;

            if (frameNumber == 0)
            {
                jointVelocity = motion->getMotionFrame(1)->joint - motion->getMotionFrame(0)->joint;
            }
            else if (frameNumber == frameCount - 1)
            {
                jointVelocity = motion->getMotionFrame(frameCount - 1)->joint - motion->getMotionFrame(frameCount - 2)->joint;
            }
            else
            {
                jointVelocity = (motion->getMotionFrame(frameNumber + 1)->joint - motion->getMotionFrame(frameNumber - 1)->joint) / 2;
            }

            jointVelocity /= timeStep;

            Eigen::Vector3f angularVelocity = (frameSegmentData[frameNumber][*curSegment].jacobian * jointVelocity).segment<3>(3);

            // Calculate inertia tensor
            Eigen::Matrix3f inertiaTensor = segment->getInertiaMatrix();
            inertiaTensor *= 1000000.0f;  // m^2 -> mm^2

            // Add contribution of this segment to angular momentum to sum
            Eigen::Vector3f segmentAM = (robotOrientation.inverse() * (segment->getCoMGlobal() - frameCoMs[frameNumber]).
                                         cross(segment->getMass() * linearVelocity)) + inertiaTensor * angularVelocity;
            entryAngularMomentum->AM += segmentAM;

            // if (*curSegment != "REsegment_joint") continue;  // Reduce debug output
            /* MMM_INFO << "Segment " << *curSegment << ":" << endl;
            MMM_INFO << "  linear velocity = " << linearVelocity.transpose() << endl;
            MMM_INFO << "  joint velocity = " << jointVelocity.transpose() << endl;
            MMM_INFO << "  angular velocity = " << angularVelocity.transpose() << endl;
            MMM_INFO << "  L = " << segmentAM.transpose() << endl; */
        }

        entryAngularMomentum->AM /= 1000000.0f;  // mm^2 -> m^2

        MMM_INFO << "Frame " << frameNumber << ": L = " << entryAngularMomentum->AM.transpose() << endl;
        motionFrame->addEntry("AngularMomentum", entryAngularMomentum);
    }

    MMM_INFO << "Writing motion file..." << endl;
    if (!MMM::XML::saveXML(c.outputMotionPath, motion->toXML()))
    {
        MMM_ERROR << " Could not write to file " << c.outputMotionPath << endl;
        return -1;
    }

    return 0;
}
